package com.chatplus.application.config.ws.interceptor;

import cn.dev33.satoken.stp.StpUtil;
import cn.hutool.http.HtmlUtil;
import com.chatplus.application.common.logging.SouthernQuietLogger;
import com.chatplus.application.common.logging.SouthernQuietLoggerFactory;
import com.chatplus.application.domain.request.ws.ChatWebSocketRequest;
import com.chatplus.application.web.satoken.helper.LoginHelper;
import com.chatplus.application.web.satoken.model.LoginUser;
import org.apache.commons.lang3.StringUtils;
import org.jetbrains.annotations.NotNull;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.stereotype.Component;
import org.springframework.util.MultiValueMap;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.server.HandshakeInterceptor;
import org.springframework.web.util.UriComponents;
import org.springframework.web.util.UriComponentsBuilder;

import java.util.Map;
import java.util.Objects;

/**
 * Interceptor for WebSocket handshake requests.
 * <p>
 * Reads the query string of the handshake URL and validates the parameters required by the
 * requested endpoint path. On success it stores a {@link ChatWebSocketRequest} in the handshake
 * attributes, keyed by that path, so the WebSocket handler can retrieve it. A missing or
 * malformed parameter rejects the handshake ({@code false}) instead of throwing.
 */
@Component
public class PlusWebSocketInterceptor implements HandshakeInterceptor {

    private static final SouthernQuietLogger LOGGER = SouthernQuietLoggerFactory.getLogger(PlusWebSocketInterceptor.class);

    /**
     * Called before the handshake; parses and validates the GET parameters for the requested path.
     *
     * @param request    the handshake (HTTP upgrade) request
     * @param response   the handshake response (unused)
     * @param wsHandler  the target WebSocket handler (unused)
     * @param attributes handshake attributes; on success a {@link ChatWebSocketRequest} is stored
     *                   under the request path
     * @return {@code true} to proceed with the handshake, {@code false} to reject it
     */
    @Override
    public boolean beforeHandshake(@NotNull ServerHttpRequest request, @NotNull ServerHttpResponse response, @NotNull WebSocketHandler wsHandler, Map<String, Object> attributes) {
        // HTML entities (e.g. "&amp;") are unescaped before parsing. Presumably some clients send
        // the query string HTML-escaped.
        String url = HtmlUtil.unescape(request.getURI().toString());
        UriComponents uriComponents;
        try {
            // fromUriString accepts any scheme. fromHttpUrl throws for non-http(s) URLs and is
            // deprecated as of Spring 6.2.
            uriComponents = UriComponentsBuilder.fromUriString(url).build();
        } catch (IllegalArgumentException e) {
            // Do not log the full URL: its query string may carry the login token.
            LOGGER.message("WebSocket握手失败，URL解析失败")
                    .context("path", request.getURI().getPath())
                    .info();
            return false;
        }
        MultiValueMap<String, String> queryParams = uriComponents.getQueryParams();
        String path = uriComponents.getPath();
        if (StringUtils.isEmpty(path)) {
            logRejected("WebSocket握手失败，path为空", path, queryParams);
            return false;
        }
        return switch (path) {
            case ChatWebSocketRequest.WS_URL_PATH -> handleChatHandshake(queryParams, attributes);
            case ChatWebSocketRequest.MJ_URL_PATH, ChatWebSocketRequest.SD_URL_PATH ->
                    handleUserIdHandshake(path, queryParams, attributes);
            default -> {
                logRejected("WebSocket握手失败，找不到对应的处理器", path, queryParams);
                yield false;
            }
        };
    }

    /**
     * Called after the handshake; intentionally does nothing.
     *
     * @param request   the handshake request
     * @param response  the handshake response
     * @param wsHandler the target WebSocket handler
     * @param exception the exception raised during the handshake, or {@code null}
     */
    @Override
    public void afterHandshake(@NotNull ServerHttpRequest request, @NotNull ServerHttpResponse response, @NotNull WebSocketHandler wsHandler, Exception exception) {

    }

    /**
     * Validates the chat endpoint parameters and stores the resulting request under
     * {@link ChatWebSocketRequest#WS_URL_PATH}.
     *
     * @return {@code true} if the token is valid and all required parameters are present and
     *         well-formed
     */
    private boolean handleChatHandshake(MultiValueMap<String, String> queryParams, Map<String, Object> attributes) {
        String token = queryParams.getFirst("token");
        Long userId = checkUserLogin(token);
        if (userId == null) {
            logRejected("WebSocket握手失败，用户未登录", ChatWebSocketRequest.WS_URL_PATH, queryParams);
            return false;
        }
        String sessionId = queryParams.getFirst("session_id");
        String chatId = queryParams.getFirst("chat_id");
        Long roleId = parseLongParam(queryParams, "role_id");
        Long modelId = parseLongParam(queryParams, "model_id");
        // These parameters were previously dereferenced unconditionally (NPE / NumberFormatException
        // when missing or malformed). Reject the handshake explicitly instead.
        if (sessionId == null || chatId == null || roleId == null || modelId == null) {
            logRejected("WebSocket握手失败，参数缺失或格式错误", ChatWebSocketRequest.WS_URL_PATH, queryParams);
            return false;
        }
        ChatWebSocketRequest chatWebSocketRequest = ChatWebSocketRequest.builder()
                .sessionId(sessionId)
                .roleId(roleId)
                .chatId(chatId)
                .modelId(modelId)
                .token(token)
                .userId(userId)
                .build();
        attributes.put(ChatWebSocketRequest.WS_URL_PATH, chatWebSocketRequest);
        return true;
    }

    /**
     * Handles the MJ and SD endpoints, which only need a {@code user_id}. The request is stored
     * under the matched path, which equals the corresponding path constant.
     *
     * @return {@code true} if {@code user_id} is present and numeric
     */
    private boolean handleUserIdHandshake(String path, MultiValueMap<String, String> queryParams, Map<String, Object> attributes) {
        // NOTE(review): user_id is taken from the query string without any token check, so a client
        // can claim any user id on these endpoints. Consider authenticating as the chat endpoint does.
        Long userId = parseLongParam(queryParams, "user_id");
        if (userId == null) {
            logRejected("WebSocket握手失败，参数缺失或格式错误", path, queryParams);
            return false;
        }
        ChatWebSocketRequest chatWebSocketRequest = ChatWebSocketRequest.builder()
                .userId(userId)
                .build();
        attributes.put(path, chatWebSocketRequest);
        return true;
    }

    /**
     * Reads the first value of a query parameter as a {@code long}.
     *
     * @return the parsed value, or {@code null} if the parameter is absent, empty or not a valid long
     */
    private static Long parseLongParam(MultiValueMap<String, String> queryParams, String name) {
        String value = queryParams.getFirst(name);
        if (StringUtils.isEmpty(value)) {
            return null;
        }
        try {
            return Long.parseLong(value);
        } catch (NumberFormatException e) {
            return null;
        }
    }

    /**
     * Logs a rejected handshake. Only the parameter names are logged, because values such as
     * {@code token} are secrets.
     */
    private static void logRejected(String reason, String path, MultiValueMap<String, String> queryParams) {
        LOGGER.message(reason)
                .context("path", path)
                .context("queryParamNames", queryParams.keySet())
                .info();
    }

    /**
     * Resolves the logged-in user for a token.
     * <p>
     * If the token is non-empty but no login user is found for it, the token is logged out via
     * {@code StpUtil.logoutByTokenValue}.
     *
     * @param token the token from the query string; may be {@code null}
     * @return the user id, or {@code null} if the token is empty or has no login user
     */
    private Long checkUserLogin(String token) {
        if (StringUtils.isEmpty(token)) {
            return null;
        }
        LoginUser loginUser = LoginHelper.getLoginUser(token);
        if (loginUser == null) {
            StpUtil.logoutByTokenValue(token);
            return null;
        }
        return loginUser.getUserId();
    }
}
